{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Validation notebook for DUC models\n",
    "\n",
    "## Overview\n",
    "Use this notebook to verify the accuracy of a trained DUC model in ONNX format on the validation set of the Cityscapes dataset.\n",
    "\n",
    "## Models supported\n",
    "* ResNet101_DUC_HDC\n",
    "\n",
    "## Prerequisites\n",
    "The following packages need to be installed before proceeding:\n",
    "* Protobuf compiler - `sudo apt-get install protobuf-compiler libprotoc-dev` (required for ONNX. This command works on Debian-based Linux systems such as Ubuntu. For detailed installation guidelines, see the [ONNX documentation](https://github.com/onnx/onnx#installation))\n",
    "* ONNX - `pip install onnx`\n",
    "* MXNet - `pip install mxnet-cu90mkl --pre -U` (tested with this GPU build; other builds can be used. `--pre` lets pip install a pre-release build of MXNet, which is required here for ONNX version compatibility. `-U` upgrades any existing installation of the package to the newest available version)\n",
    "* numpy - `pip install numpy`\n",
    "* OpenCV - `pip install opencv-python`\n",
    "* PIL - `pip install pillow`\n",
    "\n",
    "The following scripts (included in the repo) must also be present in the same folder as this notebook:\n",
    "* `cityscapes_loader.py` (load and prepare validation images and labels)\n",
    "* `utils.py` (helper script used by `cityscapes_loader.py`)\n",
    "* `cityscapes_labels.py` (contains segmentation category labels)\n",
    "\n",
    "The validation set of Cityscapes must be prepared before proceeding. Follow the guidelines in the [dataset](README.md/#dset) section.\n",
    "\n",
    "To run this validation as a Python script:\n",
    "* Generate the script: in the Jupyter Notebook browser, go to File -> Download as -> Python (.py)\n",
    "* Run the script: `python duc-validation.py`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import dependencies\n",
    "Run the cell below to verify that all dependencies are installed. Continue if no errors are raised; warnings can be ignored."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from __future__ import print_function\n",
    "import mxnet as mx\n",
    "import numpy as np\n",
    "import glob\n",
    "import os\n",
    "from mxnet.contrib.onnx import import_model\n",
    "from cityscapes_loader import CityLoader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set paths and parameters\n",
    "Prepare the validation set of Cityscapes according to the guidelines in the [dataset](README.md/#dset) section, then set the paths `data_dir` and `label_dir` accordingly."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Determine and set context\n",
    "if len(mx.test_utils.list_gpus())==0:\n",
    "    ctx = mx.cpu()\n",
    "else:\n",
    "    ctx = mx.gpu(0)\n",
    "\n",
    "# Path to validation data\n",
    "data_dir = '/home/ubuntu/TuSimple-DUC/dataset/leftImg8bit/val'\n",
    "# Path to validation labels\n",
    "label_dir = '/home/ubuntu/TuSimple-DUC/dataset/gtFine/val'\n",
    "# Set batch size\n",
    "batch_size = 16"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download ONNX model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "mx.test_utils.download('https://s3.amazonaws.com/onnx-model-zoo/duc/ResNet101_DUC_HDC.onnx')\n",
    "# Path to ONNX model\n",
    "model_path = 'ResNet101_DUC_HDC.onnx'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare dataset list\n",
    "Prepare the validation image list (`val.lst`), which contains the image and label paths along with the cropping parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "index = 0\n",
    "val_lst = []\n",
    "# images\n",
    "all_images = glob.glob(os.path.join(data_dir, '*/*.png'))\n",
    "all_images.sort()\n",
    "for p in all_images:\n",
    "    l = p.replace(data_dir, label_dir).replace('leftImg8bit', 'gtFine_labelIds')\n",
    "    if os.path.isfile(l):\n",
    "        index += 1\n",
    "        for i in range(1, 8):\n",
    "            val_lst.append([str(index), p, l, \"512\", str(256 * i)])\n",
    "\n",
    "# write the list to disk; the context manager closes (and flushes) the file\n",
    "with open('val.lst', \"w\") as val_out:\n",
    "    for line in val_lst:\n",
    "        print('\\t'.join(line), file=val_out)"
   ]
  },
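  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The list-building loop can be illustrated on a single image. The sketch below is only an illustration: `make_rows` is a hypothetical helper, the `/data/...` paths and the file name are made up, and the `os.path.isfile` check from the cell above is skipped so that the example runs without the dataset. The meaning of the last two fields is defined by `cityscapes_loader.py`; here they are simply copied from the loop.\n",
    "\n",
    "```python\n",
    "def make_rows(index, image_path, data_dir, label_dir):\n",
    "    # Hypothetical helper: builds the val.lst rows for one image.\n",
    "    # Swap the image root for the label root, then swap the file suffix.\n",
    "    label_path = image_path.replace(data_dir, label_dir)\n",
    "    label_path = label_path.replace('leftImg8bit', 'gtFine_labelIds')\n",
    "    # Seven rows per image, one for each value of i in 1..7.\n",
    "    return [[str(index), image_path, label_path, '512', str(256 * i)]\n",
    "            for i in range(1, 8)]\n",
    "\n",
    "# Made-up paths, used only for this illustration.\n",
    "example_data_dir = '/data/leftImg8bit/val'\n",
    "example_label_dir = '/data/gtFine/val'\n",
    "example_image = example_data_dir + '/frankfurt/frankfurt_000000_000294_leftImg8bit.png'\n",
    "\n",
    "rows = make_rows(1, example_image, example_data_dir, example_label_dir)\n",
    "print(len(rows))    # number of rows generated for one image\n",
    "print(rows[0][2])   # label path derived from the image path\n",
    "```\n",
    "\n",
    "Each image therefore contributes seven tab-separated rows to `val.lst`, all sharing the same index, image path and label path."
   ]
  },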
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Define evaluation metric\n",
    "`class IoUMetric` : Defines mean Intersection Over Union (mIOU) custom evaluation metric.\n",
    "\n",
    "`check_label_shapes` : Checks the shape of target labels and network output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def check_label_shapes(labels, preds, shape=0):\n",
    "    if shape == 0:\n",
    "        label_shape, pred_shape = len(labels), len(preds)\n",
    "    else:\n",
    "        label_shape, pred_shape = labels.shape, preds.shape\n",
    "\n",
    "    if label_shape != pred_shape:\n",
    "        raise ValueError(\"Shape of labels {} does not match shape of \"\n",
    "                         \"predictions {}\".format(label_shape, pred_shape))\n",
    "\n",
    "class IoUMetric(mx.metric.EvalMetric):\n",
    "    def __init__(self, ignore_label, label_num, name='IoU'):\n",
    "        self._ignore_label = ignore_label\n",
    "        self._label_num = label_num\n",
    "        super(IoUMetric, self).__init__(name=name)\n",
    "\n",
    "    def reset(self):\n",
    "        self._tp = [0.0] * self._label_num\n",
    "        self._denom = [0.0] * self._label_num\n",
    "\n",
    "    def update(self, labels, preds):\n",
    "        check_label_shapes(labels, preds)\n",
    "        for i in range(len(labels)):\n",
    "            pred_label = mx.ndarray.argmax_channel(preds[i]).asnumpy().astype('int32')\n",
    "            label = labels[i].asnumpy().astype('int32')\n",
    "\n",
    "            check_label_shapes(label, pred_label)\n",
    "\n",
    "            iou = 0\n",
    "            eps = 1e-6\n",
    "            for j in range(self._label_num):\n",
    "                pred_cur = (pred_label.flat == j)\n",
    "                gt_cur = (label.flat == j)\n",
    "                tp = np.logical_and(pred_cur, gt_cur).sum()\n",
    "                denom = np.logical_or(pred_cur, gt_cur).sum() - np.logical_and(pred_cur, label.flat == self._ignore_label).sum()\n",
    "                assert tp <= denom\n",
    "                self._tp[j] += tp\n",
    "                self._denom[j] += denom\n",
    "                iou += self._tp[j] / (self._denom[j] + eps)\n",
    "            iou /= self._label_num\n",
    "            self.sum_metric = iou\n",
    "            self.num_inst = 1\n",
    "\n",
    "            \n",
    "# Create evaluation metric\n",
    "met = IoUMetric(ignore_label=255, label_num=19, name=\"IoU\")\n",
    "metric = mx.metric.create(met)"
   ]
  },
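  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-class computation inside `IoUMetric.update` can be checked on a toy example. The sketch below is a minimal NumPy version for a single flattened image; `per_class_iou` is a hypothetical helper that is not part of this notebook or of MXNet, and unlike `IoUMetric` it does not accumulate counts across batches.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def per_class_iou(label, pred, num_classes, ignore_label=255):\n",
    "    # Hypothetical helper: IoU of each class for one flattened image.\n",
    "    label = np.asarray(label).ravel()\n",
    "    pred = np.asarray(pred).ravel()\n",
    "    ignored = (label == ignore_label)\n",
    "    ious = []\n",
    "    for c in range(num_classes):\n",
    "        pred_c = (pred == c)\n",
    "        gt_c = (label == c)\n",
    "        # intersection: pixels where prediction and ground truth agree on c\n",
    "        tp = np.logical_and(pred_c, gt_c).sum()\n",
    "        # union, minus predictions of c that fall on ignored pixels\n",
    "        denom = np.logical_or(pred_c, gt_c).sum() - np.logical_and(pred_c, ignored).sum()\n",
    "        ious.append(tp / (denom + 1e-6))\n",
    "    return ious\n",
    "\n",
    "# Six pixels, two classes; the last two pixels carry the ignore label.\n",
    "toy_label = [0, 0, 1, 1, 255, 255]\n",
    "toy_pred = [0, 1, 1, 1, 0, 1]\n",
    "ious = per_class_iou(toy_label, toy_pred, num_classes=2)\n",
    "print(ious)            # approximately [0.5, 0.667]\n",
    "print(np.mean(ious))   # mean IoU over the two classes\n",
    "```\n",
    "\n",
    "Predictions that land on ignored pixels are removed from the denominator, so they neither help nor hurt the score."
   ]
  },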
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configure data loader\n",
    "An object of the `CityLoader` class (inherited from `mx.io.DataIter`) is instantiated for loading and preprocessing the validation data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "loader = CityLoader\n",
    "val_args = {\n",
    "    'data_path'             : data_dir,\n",
    "    'label_path'            : label_dir,\n",
    "    'rgb_mean'              : (122.675, 116.669, 104.008),\n",
    "    'batch_size'            : batch_size,\n",
    "    'scale_factors'         : [1],\n",
    "    'data_name'             : 'data',\n",
    "    'label_name'            : 'seg_loss_label',\n",
    "    'data_shape'            : [tuple(list([batch_size, 3, 800, 800]))],\n",
    "    'label_shape'           : [tuple([batch_size, (160000)])],\n",
    "    'use_random_crop'       : False,\n",
    "    'use_mirror'            : False,\n",
    "    'ds_rate'               : 8,\n",
    "    'convert_label'         : True,\n",
    "    'multi_thread'          : False,\n",
    "    'cell_width'            : 2,\n",
    "    'random_bound'          : [120,120],\n",
    "}\n",
    "val_dataloader = loader('val.lst', val_args)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load ONNX model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# import ONNX model into MXNet symbols and params\n",
    "sym,arg,aux = import_model(model_path)\n",
    "# define network module\n",
    "mod = mx.mod.Module(symbol=sym, data_names=['data'], context=ctx, label_names=None)\n",
    "# bind parameters to the network\n",
    "mod.bind(for_training=False, data_shapes=[('data', (batch_size, 3, 800, 800))], label_shapes=mod._label_shapes)\n",
    "mod.set_params(arg_params=arg, aux_params=aux,allow_missing=True, allow_extra=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Compute evaluations\n",
    "Perform a forward pass over each batch and update the evaluation metric."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 / 218 batches done\n",
      "10 / 218 batches done\n",
      "20 / 218 batches done\n",
      "30 / 218 batches done\n",
      "40 / 218 batches done\n",
      "50 / 218 batches done\n",
      "60 / 218 batches done\n",
      "70 / 218 batches done\n",
      "80 / 218 batches done\n",
      "90 / 218 batches done\n",
      "100 / 218 batches done\n",
      "110 / 218 batches done\n",
      "120 / 218 batches done\n",
      "130 / 218 batches done\n",
      "140 / 218 batches done\n",
      "150 / 218 batches done\n",
      "160 / 218 batches done\n",
      "170 / 218 batches done\n",
      "180 / 218 batches done\n",
      "190 / 218 batches done\n",
      "200 / 218 batches done\n",
      "210 / 218 batches done\n"
     ]
    }
   ],
   "source": [
    "# reset data loader\n",
    "val_dataloader.reset()\n",
    "# reset evaluation metric\n",
    "metric.reset()\n",
    "# loop over batches\n",
    "for nbatch, eval_batch in enumerate(val_dataloader):\n",
    "    # perform forward pass\n",
    "    mod.forward(eval_batch, is_train=False)\n",
    "    # get outputs\n",
    "    outputs=mod.get_outputs()\n",
    "    # update evaluation metric\n",
    "    metric.update(eval_batch.label,outputs)\n",
    "    # print progress (3500 is the expected number of entries in val.lst: 7 crops per validation image)\n",
    "    if nbatch%10==0:\n",
    "        print('{} / {} batches done'.format(nbatch,int(3500/batch_size)))"
   ]
  },
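  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The total of 218 in the progress messages follows from the sample count and the batch size. The arithmetic is sketched below, assuming the Cityscapes validation split has 500 images (an assumption that is consistent with the hard-coded 3500, but not checked by this notebook). All names are local to this illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "example_num_images = 500        # assumption: size of the Cityscapes validation split\n",
    "example_crops_per_image = 7     # one val.lst row for each i in 1..7\n",
    "example_batch_size = 16\n",
    "\n",
    "num_samples = example_num_images * example_crops_per_image\n",
    "# Full batches only, as printed by the progress message.\n",
    "full_batches = num_samples // example_batch_size\n",
    "# Batches needed if a final partial batch is also processed.\n",
    "all_batches = int(math.ceil(num_samples / float(example_batch_size)))\n",
    "print(num_samples, full_batches, all_batches)\n",
    "```\n",
    "\n",
    "Whether the final partial batch is actually produced depends on how `CityLoader` handles the end of the list, so the printed total should be read as approximate."
   ]
  },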
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Print results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean Intersection Over Union (mIOU): 0.819220680835\n"
     ]
    }
   ],
   "source": [
    "print(\"mean Intersection Over Union (mIOU): {}\".format(metric.get()[1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
